"""
用于从标注图像生成训练集和测试集，每行的格式为：
```
    gt_image_path,raw_image_path
```
均为相对于dataset_dir的路径
"""

import os
import random

# Probability that a sample is written to val.txt instead of train.txt.
# Every sample is drawn independently, so the realised split only
# approximates this ratio.
cross_rate = 0.1
# Dataset root; each immediate subdirectory is processed as one labelled set
# that provides a "gt_segment.txt" index of sample names.
dataset_dir = "/disk527/sdb1/a804_cbf/datasets/collect_data"

# os.listdir() returns entries in arbitrary order; sorting keeps the order of
# the generated lists stable between runs (the random split itself is not
# seeded). Plain files such as train.txt / val.txt are filtered out here.
label_dirs = sorted(
    path
    for path in (os.path.join(dataset_dir, name) for name in os.listdir(dataset_dir))
    if os.path.isdir(path)
)

train_path = os.path.join(dataset_dir, "train.txt")
val_path = os.path.join(dataset_dir, "val.txt")

# Mode "w" truncates any list left over from a previous run, so the outputs
# are opened once instead of being deleted and re-opened in append mode for
# every subdirectory.
with open(train_path, "w") as train_f, open(val_path, "w") as val_f:
    for label_dir in label_dirs:
        basename = os.path.basename(label_dir)
        try:
            with open(os.path.join(label_dir, "gt_segment.txt"), "r") as index_f:
                stems = [line.strip() for line in index_f]
        except FileNotFoundError:
            # A directory without an index is reported and skipped rather than
            # aborting the run and leaving half-written lists behind.
            print(f"{label_dir} skipped: gt_segment.txt not found")
            continue

        for stem in stems:
            if not stem:
                # A blank index line would otherwise produce a bogus ".png" entry.
                continue
            gt = os.path.join(basename, "gt_segment", stem + ".png")
            raw = os.path.join(basename, "images", stem + ".png")
            # random.random() is uniform on [0, 1): a draw above cross_rate
            # sends the sample to the training list, otherwise to validation.
            target = train_f if cross_rate < random.random() else val_f
            print(f"{gt},{raw}", file=target)
        print(f"{label_dir} done")
